#
# Estimate bootstrap standard errors (SEs) for the model parameters
#

#### Set-Up ####

rm(list = ls())
Sys.info()[7]

# Set file paths. They are relative to the directory the script is launched
# from, which is expected to be the code folder.
DATA <- "../external_links/real"
TEMP <- "../temp"
RESULTS <- "../output"

# Mirror all printed output into a log file; closed by sink() at the end of
# the script.
sink("bootstrap_se.txt")

# Make sure the scratch and output folders for this run exist.
# recursive = TRUE also creates a missing TEMP / RESULTS parent instead of
# only emitting a warning and leaving the folder absent.
for (folder in file.path(c(TEMP, RESULTS), "estimate_w_xs")) {
  if (dir.exists(folder)) {
    cat("The folder already exists:", folder, "\n")
  } else {
    dir.create(folder, recursive = TRUE)
  }
}

Sys.info()[7]
Sys.time()

# Packages
library(MASS)
library(tmvtnorm)
library(haven)
library(dplyr)
library(tidyr)
library(nloptr)
library(corpcor)
library(stargazer)
library(pracma)
library(foreign)
library(writexl)
library(broom)
library(png)
library(tidyverse)


# Record (and log) when the run started; used for the total run time at the end.
print("Begin time")
print(start_time <- Sys.time())


#### Load Data + Assumptions ####
# Read the estimation sample and derive the indicator covariates used below.
analysis <- read_dta(paste0(DATA, "/model_sample.dta")) %>%
  mutate(
    # Missing oop_i is recoded to Inf.
    oop_i = ifelse(is.na(oop_i), Inf, oop_i),
    # Age bands are half-open so each age >= 25 lands in exactly one band.
    # NOTE: age_35_plus now uses >= 35; the former `age > 35` left
    # 35-year-olds out of both age_25_34 and age_35_plus.
    age_25_34 = ifelse(age >= 25 & age < 35, 1, 0),
    age_35_plus = ifelse(age >= 35, 1, 0),
    # Missing education is treated as 0 for the schooling dummies and is
    # flagged separately in missing_college.
    some_college = ifelse(educ == 1 | is.na(educ), 0, 1),
    missing_college = ifelse(is.na(educ), 1, 0),
    any_prev_birth_issue = ifelse(
      prev_concern == 1 | prev_any_q_icd == 1 | prev_mis_still == 1, 1, 0
    ),
    inc_quartile_4 = ifelse(inc_quartile != 4 | is.na(inc_quartile), 0, 1),
    full_college = ifelse(educ != 3 | is.na(educ), 0, 1)
  )


print(DATA)

# Model constants. They are not referenced directly in this script and are
# presumably read as globals by the sourced model functions (confirm).
pa <- 5e-3
pfp <- 1e-2
pfn <- 1e-2
J <- 1

# Disabled cleanup of an older output layout (".../Prenatal/..." paths):
#file.remove(paste0(TEMP, "/Prenatal/estimate_w_xs/estimates.csv"))
#file.remove(paste0(RESULTS, "/Prenatal/estimate_w_xs/estimates.csv"))


#### Load Necessary Model Functions ####

# Presumably defines moments_wgts() and obj(), which this script calls but
# does not define itself.
source(file.path("r_functions", "production_model_xs_new.R"))

#### ESTIMATE ####
print("Estimate on actual data")

# Moments of the full estimation sample plus their weights.
# moments_wgts() returns a list: element 1 = moments, element 2 = weights.
actual <- moments_wgts(data = analysis)
weights <- actual[[2]]
data_moments <- actual[[1]]
rm(actual)
print("calculated actual moments")

# Save the data moments next to their weights for reference.
data_moments_table <- data.frame(data_moments, weights)
write.csv(
  data_moments_table,
  file = file.path(RESULTS, "estimate_w_xs", "data_moments_table.csv"),
  na = ".",
  row.names = TRUE
)


# Define objective functions.
# target() evaluates obj() at parameter vector x for the current bootstrap
# draw. bootstrap_sample, bootstrap_data_moments and bootstrap_weights are
# looked up as globals at call time; the bootstrap loop reassigns them.
target <- function(x) {
  obj(x, bootstrap_sample, bootstrap_data_moments, bootstrap_weights,
      f_a = f_a, f_c = f_c, x_vars = x_vars, unique_flag = 1, info = FALSE)
}
# Numerical gradient of target().
g_target <- function(x) {
  nl.grad(x, target)
}


# Bounds when imposing coefficients on the Xs = 0 (i.e. estimating without
# Xs): six free parameters, then 14 coefficients pinned to [0, 0].
lb <- c(-300000, -300000, 1, 1, -0.9, 0.5, rep(0, 14))
ub <- c(0, 0, 100000, 100000, 0.9, 1, rep(0, 14))

# 1 x 1 character matrix used as a visual separator in (disabled) log writes.
divider <- t(c("####################"))

# Wrapper around nloptr() used to estimate the model on each bootstrap draw.
#
# Despite its name, only the local stage currently runs. The global search
# (alg1) is disabled, so alg1 is accepted but unused. If it is re-enabled,
# the former global-stage options were: maxeval = 10000, xtol_rel = 0,
# pop.size = (length(x0) + 1) * 100 (lower the 100 to ~10 if too slow),
# print_level = 3, ranseed = 4792.
#
# Arguments:
#   x0        numeric starting vector for the optimizer.
#   eval_f    objective; a function of one numeric vector.
#   eval_grad gradient of eval_f. Forwarded to nloptr() only when alg2 is a
#             gradient-based algorithm (NLOPT_LD_* / NLOPT_GD_*); the default
#             alg2 (SBPLX) is derivative-free and does not use it.
#   lb, ub    lower / upper bounds, same length as x0.
#   alg1      NLOPTR algorithm string for the (disabled) global stage.
#   alg2      NLOPTR algorithm string for the local stage.
# Returns a length-1 list holding the nloptr result of the local stage.
doublesolve <- function(
  x0,
  eval_f,
  eval_grad = NA,
  lb,
  ub,
  alg1 = "NLOPT_GN_DIRECT_NOSCAL",
  alg2 = "NLOPT_LN_SBPLX"
) {
  o2 <- list(
    "maxeval" = 100000,
    "xtol_rel" = 1.0e-8,
    "algorithm" = alg2,
    "print_level" = 3,
    "ranseed" = 2837
  )
  # Reset the global RNG before every solve (presumably so any simulation
  # inside eval_f sees the same draws on each call; confirm).
  set.seed(101)

  # Start from the caller's x0. Previously the global `par0` was used here,
  # which silently ignored the x0 argument.
  use_grad <- is.function(eval_grad) && grepl("^NLOPT_[GL]D_", alg2)
  if (use_grad) {
    s2 <- nloptr(x0 = x0, eval_f = eval_f, eval_grad_f = eval_grad,
                 lb = lb, ub = ub, opts = o2)
  } else {
    s2 <- nloptr(x0 = x0, eval_f = eval_f, lb = lb, ub = ub, opts = o2)
  }
  print("Results of Local optimization")
  print(s2)

  list(s2)
}

####### Normalize X's variables #############
# Demean each covariate over the full estimation sample. The means are taken
# once here, so bootstrap draws inherit these full-sample means.
analysis <- analysis %>%
  mutate(
    full_college_norm = full_college - mean(full_college),
    fc_n = full_college - mean(full_college),
    inc_quartile_4_norm = inc_quartile_4 - mean(inc_quartile_4),
    married_norm = married - mean(married),
    mom_foreign_norm = mom_foreign - mean(mom_foreign),
    any_prev_birth_issue_norm = any_prev_birth_issue - mean(any_prev_birth_issue),
    prev_kids_norm = dv_prev_kids - mean(dv_prev_kids),
    age_35_plus_norm = age_35_plus - mean(age_35_plus)
  )

# Formulas passed to obj() as f_a and f_c: the same seven demeaned
# covariates, no intercept.
f_a <- ~ 0 + full_college_norm + inc_quartile_4_norm + married_norm + mom_foreign_norm + any_prev_birth_issue_norm + age_35_plus_norm + prev_kids_norm
f_c <- f_a
x_vars <- c("full_college_norm", "inc_quartile_4_norm", "married_norm", "mom_foreign_norm", "any_prev_birth_issue_norm", "age_35_plus_norm", "prev_kids_norm")

###### Optimize ####
# Starting values: the first column of the no-X global estimates saved by an
# earlier run, with the 21st entry dropped (presumably leaving one value per
# bound in lb/ub; confirm against the file layout).
# To add mean-shifting covariates, append one coefficient per extra term in
# f_a / f_c to this vector and extend lb / ub to match.
par0 <- read.csv(
  file.path(RESULTS, "estimate_w_xs", "estimates_global_no_xs.csv"),
  header = FALSE
)
par0 <- as.vector(as.matrix(par0)[, 1])[-21]
print(par0)

# Bootstrap: re-estimate the model on N_bootstrap resamples (drawn with
# replacement) of `analysis`. Row x of bootstrap_results holds draw x's
# parameter estimates, followed by the objective value and the iteration
# count reported by the optimizer.
N_bootstrap <- 100
n_obs <- nrow(analysis)
bootstrap_results <- matrix(NA, nrow = N_bootstrap, ncol = length(par0) + 2)
# Row x holds the resampled row indices of draw x. It is preallocated rather
# than grown with rbind() on every iteration.
bs_sample_each <- matrix(NA_integer_, nrow = N_bootstrap, ncol = n_obs)

for (x in seq_len(N_bootstrap)) {
  # Seed per draw so any single draw can be reproduced on its own.
  seed <- x + 321
  set.seed(seed)
  bootstrap_index <- sample(seq_len(n_obs), size = n_obs, replace = TRUE)
  bs_sample_each[x, ] <- bootstrap_index

  # target() reads bootstrap_sample, bootstrap_data_moments and
  # bootstrap_weights as globals, so they are (re)assigned for each draw.
  bootstrap_sample <- analysis[bootstrap_index, ]
  actual <- moments_wgts(data = bootstrap_sample)
  bootstrap_data_moments <- actual[[1]]
  bootstrap_weights <- actual[[2]]
  rm(actual)
  print("calculated actual moments")
  print(nrow(bootstrap_sample))

  # system.time() is not auto-printed inside a loop, so print it explicitly.
  timing <- system.time(
    estimate <- doublesolve(par0, target, g_target, lb, ub)
  )
  print(timing)

  est <- estimate[[1]]
  bootstrap_results[x, ] <- c(est$solution, est$objective, est$iterations)

  rm(bootstrap_index, bootstrap_sample)
}

bs_file <- file.path(RESULTS, "estimate_w_xs", "bootstrap_se.csv")
write.csv(bootstrap_results, file = bs_file, na = ".", row.names = TRUE)

# Summaries are computed from the saved CSV so they match what is on disk.
# na.strings = "." undoes the na = "." used when writing. Without it, one
# missing draw would turn its whole column into character data and break
# sd() / quantile().
real_bs_table <- read.csv(bs_file, header = TRUE, na.strings = ".")
real_bs_table <- real_bs_table[, -1]  # drop the row-number column from row.names = TRUE

# Only the first six (structural) parameters are summarised; later columns
# hold the X coefficients, the objective value and the iteration count.
param_names <- c("mu_a", "mu_c", "sigma_a", "sigma_c", "rho", "psi")
row_labels <- paste0("q_", param_names)
param_draws <- real_bs_table[, seq_along(param_names), drop = FALSE]

# na.rm = TRUE: a failed (NA) draw is skipped instead of making the summary NA.
sd_vector <- setNames(vapply(param_draws, sd, numeric(1), na.rm = TRUE), row_labels)
mean_vector <- setNames(vapply(param_draws, mean, numeric(1), na.rm = TRUE), row_labels)
median_vector <- setNames(vapply(param_draws, median, numeric(1), na.rm = TRUE), row_labels)
# 95% percentile interval: one row per parameter, columns "2.5%" and "97.5%".
q_vector <- t(vapply(param_draws, quantile, numeric(2),
                     probs = c(0.025, 0.975), na.rm = TRUE))
rownames(q_vector) <- row_labels

bootstrap_stats <- cbind(sd_vector, mean_vector, median_vector, q_vector)

write.csv(bootstrap_stats, file = file.path(RESULTS, "estimate_w_xs", "bootstrap_stats.csv"), na = ".", row.names = TRUE)

print("End time")
print(end_time <- Sys.time())

# Total wall-clock duration of the run.
print(time_diff <- difftime(end_time, start_time))

#### EXIT ####

# Stop mirroring output into the log file opened at set-up.
sink()